package org.smartboot.pass;

import org.apache.commons.lang3.math.NumberUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.smartboot.http.client.Body;
import org.smartboot.http.client.HttpClient;
import org.smartboot.http.client.HttpRest;
import org.smartboot.http.client.ResponseHandler;
import org.smartboot.http.client.impl.Response;
import org.smartboot.http.common.enums.HeaderNameEnum;
import org.smartboot.http.common.enums.HttpMethodEnum;
import org.smartboot.http.common.enums.HttpStatus;
import org.smartboot.http.common.enums.HttpTypeEnum;
import org.smartboot.http.common.utils.StringUtils;
import org.smartboot.http.server.HttpRequest;
import org.smartboot.http.server.HttpResponse;
import org.smartboot.http.server.HttpServerHandler;
import org.smartboot.http.server.impl.Request;
import org.smartboot.socket.MessageProcessor;
import org.smartboot.socket.StateMachineEnum;
import org.smartboot.socket.transport.AioQuickClient;
import org.smartboot.socket.transport.AioSession;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.List;

/**
 * @author 三刀（zhengjunweimail@163.com）
 * @version V1.0 , 2021/7/14
 */
public class ProxyConnectionServerHandler extends HttpServerHandler {
    private static final Logger LOGGER = LoggerFactory.getLogger(ProxyConnectionServerHandler.class);
    private static final List<String> BLACK_DOMAIN = Arrays.asList("content-autofill.googleapis.com", "www.google.com", "play.google.com");
    private final Config config;
    private final boolean proxyEnabled;

    /**
     * Http 明文代理
     */
    private HttpRest httpRest;
    private HttpClient httpClient;
    private int httpRequestContentLength;

    private boolean closed = false;

    private AioQuickClient proxyClient;
    private HttpClient proxyHttpClient;
    private HttpRest proxyHttpRest;


    public ProxyConnectionServerHandler(Config config) {
        this.config = config;
        this.proxyEnabled = StringUtils.isNotBlank(config.getProxyHost()) && config.getProxyPort() > 0;
    }

    @Override
    public void onHeaderComplete(Request request) throws IOException {
        super.onHeaderComplete(request);
        logHeader(request);
        if (closed) {
            return;
        }
        // Proxy --> Client
        HttpResponse clientResponse = request.getRequestType() == HttpTypeEnum.WEBSOCKET ?
                request.newWebsocketRequest().getResponse() : request.newHttpRequest().getResponse();

        //授权校验
        if (!basicCheck(request, clientResponse)) {
            return;
        }

        String host = request.getHeader(HeaderNameEnum.HOST.getName());
        String[] array = host.split(":");
        String ip = array[0];

        //黑名单域名，直接拦截
        if (BLACK_DOMAIN.contains(ip)) {
            LOGGER.warn("black domain {}", ip);
            clientResponse.close();
            return;
        }

        if (HttpMethodEnum.CONNECT.getMethod().equals(request.getMethod())) {
            acceptHttpsProxy(request, clientResponse, ip, array.length == 2 ? NumberUtils.toInt(array[1]) : 443);
        } else {
            acceptHttpProxy(request, clientResponse, ip, array.length == 2 ? NumberUtils.toInt(array[1]) : 80);
        }
    }

    /**
     * 建立 Http 请求的代理连接
     */
    private void acceptHttpProxy(Request request, HttpResponse clientResponse, String ip, int port) {
        httpRequestContentLength = request.getContentLength();

        // 建立 TCP 连接
        httpClient = new HttpClient(ip, port);
        if (proxyEnabled) {
            httpClient.configuration().proxy(config.getProxyHost(), config.getProxyPort(), config.getProxyUserName(), config.getProxyPassword());
        }
        httpClient.setAsynchronousChannelGroup(config.getProxyChanelGroup());
        httpClient.configuration().connectTimeout(3000).writeBufferPool(config.getBufferPagePool()).readBufferPool(config.getHttpClientReadBufferPagePool());
        // 发送 http 请求
        httpRest = httpClient.rest(getUir(request)).setMethod(request.getMethod());
        request.getHeaderNames().stream()
                //会自动填充host
                .filter(headerName -> !(HeaderNameEnum.HOST.getName().equalsIgnoreCase(headerName)
                        || HeaderNameEnum.USER_AGENT.getName().equalsIgnoreCase(headerName)))
                .forEach(headerName -> httpRest.header().add(headerName, request.getHeader(headerName)));
        //重写 user_agent
        httpRest.header().add(HeaderNameEnum.USER_AGENT.getName(), request.getHeader(HeaderNameEnum.USER_AGENT.getName()) + ";smart-http")
                .done()
                .onResponse(new RemoteHttpProxyResponseHandler(clientResponse))
                .onSuccess(proxyResponse -> {
                    try {
                        clientResponse.setHttpStatus(proxyResponse.getStatus(), proxyResponse.getReasonPhrase());
                        proxyResponse.getHeaderNames()
                                .forEach(headerName -> clientResponse.setHeader(headerName, proxyResponse.getHeader(headerName)));
                        clientResponse.getOutputStream().flush();
                        request.getAioSession().close(false);
                    } catch (IOException e) {
                        LOGGER.error("", e);
                    }
                })
                .onFailure(throwable -> {
                    LOGGER.error("", throwable);
                    clientResponse.setHttpStatus(HttpStatus.INTERNAL_SERVER_ERROR.value(), throwable.getMessage());
                    clientResponse.close();
                });
    }

    /**
     * 建立 Https 请求的代理连接
     */
    private void acceptHttpsProxy(Request request, HttpResponse clientResponse, String ip, int port) {
        if (proxyEnabled) {
            connectToProxyServer(request, clientResponse, ip, port);
            return;
        }
        try {
            clientResponse.setHttpStatus(HttpStatus.OK.value(), "Connection Established");
            clientResponse.getOutputStream().flush();
            if (clientResponse.getHttpStatus() == HttpStatus.OK.value()) {
                proxyClient = createProxyClient(ip, port, request);
            }
        } catch (Exception e) {
            LOGGER.error("connect to {}:{}", ip, port, e);
            clientResponse.close();
        }
    }

    /**
     * 打印 Http Header信息
     */
    private void logHeader(Request request) {
        StringBuilder stringBuilder = new StringBuilder("receive client request\r\n");
        stringBuilder.append(request.getMethod()).append(' ').append(getUir(request)).append("\r\n");
        request.getHeaderNames().forEach(headerName -> stringBuilder.append(headerName).append(" : ").append(request.getHeader(headerName)).append("\r\n"));
        stringBuilder.append("\r\n");
        LOGGER.info(stringBuilder.toString());
    }

    private String getUir(Request request) {
        if (StringUtils.isBlank(request.getQueryString())) {
            return request.getUri();
        } else {
            return request.getUri() + "?" + request.getQueryString();
        }
    }

    /**
     * Https 二次代理
     */
    private void connectToProxyServer(Request request, HttpResponse clientResponse, String ip, int port) {
        proxyHttpClient = new HttpClient(ip, port);
        proxyHttpClient.configuration().connectTimeout(3000)
                .proxy(config.getProxyHost(), config.getProxyPort(), config.getProxyUserName(), config.getProxyPassword())
                .writeBufferPool(config.getBufferPagePool()).readBufferPool(config.getHttpClientReadBufferPagePool());
        proxyHttpClient.setAsynchronousChannelGroup(config.getProxyChanelGroup());
        proxyHttpRest = proxyHttpClient.rest(getUir(request)).setMethod(HttpMethodEnum.CONNECT.getMethod())
                .onResponse(new ResponseHandler() {
                    @Override
                    public void onHeaderComplete(Response response) {
                        if (response.getStatus() == HttpStatus.PROXY_AUTHENTICATION_REQUIRED.value()) {
                            clientResponse.setHttpStatus(HttpStatus.INTERNAL_SERVER_ERROR);
                            try {
                                clientResponse.getOutputStream().write("proxy server config error".getBytes(StandardCharsets.UTF_8));
                            } catch (IOException e) {
                                LOGGER.error("", e);
                            }
                            clientResponse.close();
                            return;
                        }
                        clientResponse.setHttpStatus(response.getStatus(), response.getReasonPhrase());
                        try {
                            clientResponse.getOutputStream().flush();
                        } catch (IOException e) {
                            LOGGER.error("", e);
                        }
                    }

                    @Override
                    public boolean onBodyStream(ByteBuffer buffer, Response response) {
                        byte[] bytes = new byte[buffer.remaining()];
                        buffer.get(bytes);
                        try {
                            clientResponse.getOutputStream().write(bytes);
                            clientResponse.getOutputStream().flush();
                        } catch (IOException e) {
                            LOGGER.error("", e);
                        }
                        return false;
                    }
                })
                .onSuccess(response -> {
                }).onFailure(throwable -> LOGGER.error("", throwable));
        proxyHttpRest.done();
    }

    private boolean basicCheck(Request request, HttpResponse response) {
        if (config.getBasic() == null || StringUtils.equals(request.getHeader(HeaderNameEnum.PROXY_AUTHORIZATION.getName()), config.getBasic())) {
            return true;
        }
        response.setHeader(HeaderNameEnum.PROXY_AUTHENTICATE.getName(), "Basic realm=\"Access to internal site\"");
        response.setHttpStatus(HttpStatus.PROXY_AUTHENTICATION_REQUIRED);
        try {
            response.getOutputStream().flush();
            response.close();
        } catch (IOException e) {
            LOGGER.error("", e);
        }
        return false;
    }

    /**
     * 将客户端提交的 Http Body  内容转发至远程服务器
     */
    @Override
    public boolean onBodyStream(ByteBuffer buffer, Request request) {
        if (closed) {
            LOGGER.warn("body closed");
            return true;
        }
        if (HttpMethodEnum.CONNECT.getMethod().equals(request.getMethod())) {
            sendHttpsBodyToRemoteServer(buffer);
        } else {
            sendHttpBodyToRemoteServer(buffer, request);
        }
        return false;
    }

    @Override
    public void handle(HttpRequest request, HttpResponse response) throws IOException {
        System.out.println("aaaaa");
    }

    private void sendHttpsBodyToRemoteServer(ByteBuffer buffer) {
        byte[] bytes = new byte[buffer.remaining()];
        buffer.get(bytes);
        //https 二次代理
        if (proxyEnabled) {
            LOGGER.info("onBodyStream https proxy");
            proxyHttpRest.body().write(bytes, 0, bytes.length).flush();
        } else {
            LOGGER.info("onBodyStream https");
            try {
                AioSession session = proxyClient.getSession();
                session.writeBuffer().write(bytes);
                session.writeBuffer().flush();
            } catch (IOException e) {
                LOGGER.error("", e);
                //todo
//                onClose();
            }
        }
    }

    private void sendHttpBodyToRemoteServer(ByteBuffer buffer, Request request) {
        LOGGER.info("onBodyStream http {}", request.getUri());
        if (httpRequestContentLength >= 0) {
            sendToRemoteServerWithContentLength(buffer);
        } else {
            sendToRemoteServerWithoutLength(buffer, request);
        }
    }

    /**
     * 客户端提交的Http请求不包含 Content-Length，可能是GET请求
     */
    private void sendToRemoteServerWithoutLength(ByteBuffer buffer, Request request) {
        int pos = buffer.position();
        boolean finish = super.onBodyStream(buffer, request);
        Body bodyStream = httpRest.body()
                .write(buffer.array(), pos + buffer.arrayOffset(), buffer.position() - pos);
        if ((buffer.position() - pos) > 0) {
            System.out.println("aaaaaaaaaaaaaa");
        }
        if (finish) {
            LOGGER.info("flush_2");
            bodyStream.flush();
        }
    }

    /**
     * 客户端提交的请求包含 Content-Length
     */
    private void sendToRemoteServerWithContentLength(ByteBuffer buffer) {
        int readSize = Math.min(httpRequestContentLength, buffer.remaining());
        Body bodyStream = httpRest.body()
                .write(buffer.array(), buffer.position() + buffer.arrayOffset(), readSize);
        buffer.position(buffer.position() + readSize);
        httpRequestContentLength -= readSize;
        if (httpRequestContentLength == 0) {
            LOGGER.info("flush");
            bodyStream.flush();
        }
    }

    @Override
    public void onClose(Request request) {
        super.onClose(request);
        closed = true;
        if (proxyClient != null) {
            try {
                proxyClient.shutdownNow();
            } catch (Exception ignored) {

            }
            proxyClient = null;
        }
        if (httpClient != null) {
            try {
                httpClient.close();
            } catch (Exception ignored) {

            }
            httpClient = null;
        }
        if (proxyHttpClient != null) {
            try {
                proxyHttpClient.close();
            } catch (Exception ignored) {

            }
            proxyHttpClient = null;
        }
    }

    /**
     * 生成代理客户端
     */
    private AioQuickClient createProxyClient(String host, int port, Request request) throws IOException {
        AioQuickClient client = new AioQuickClient(host, port, (readBuffer, session) -> {
            byte[] bytes = new byte[readBuffer.remaining()];
            LOGGER.info("uri: {} read data size: {}", getUir(request), bytes.length);
            readBuffer.get(bytes);
            return bytes;
        }, new MessageProcessor<byte[]>() {
            @Override
            public void process(AioSession session, byte[] msg) {
                try {
                    request.getAioSession().writeBuffer().write(msg);
                    request.getAioSession().writeBuffer().flush();
                } catch (IOException e) {
                    LOGGER.error("", e);
                }
            }

            @Override
            public void stateEvent(AioSession session, StateMachineEnum stateMachineEnum, Throwable throwable) {
                if (throwable != null &&
                        (stateMachineEnum == StateMachineEnum.PROCESS_EXCEPTION || stateMachineEnum == StateMachineEnum.DECODE_EXCEPTION)) {
                    LOGGER.error("stateEvent", throwable);
                }
            }
        });
        client.setBufferPagePool(config.getBufferPagePool());
        client.setReadBufferSize(1024 * 1024).connectTimeout(10000);
        client.start(config.getProxyChanelGroup());
        return client;
    }
}
